import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from alg_PPO_Discrete_big import PPO


# ! git clone https://github.com/boyu-ai/ma-gym.git
import sys
# sys.path.append("F:\ma-gym\ma_gym")
sys.path.append("..\ma-gym\ma_gym")
from envs.combat.combat import Combat


actor_lr = 3e-4
critic_lr = 1e-3
num_episodes = 100000
hidden_dim = 64
discount_fac = 0.99
epochs = 1
para_GAE_lmbda = 0.97
para_PPO_clip = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# 在二维的格子世界上进行的两个队伍的对战模拟游戏
# 每个智能体的动作集合为：向四周移动格，攻击周围格范围内其他敌对智能体，或者不采取任何行动。
# 起初每个智能体有 3 点生命值
# 如果智能体在敌人的攻击范围内被攻击到了，则会扣 1 生命值
# 生命值掉为 0 则死亡，最后存活的队伍获胜。
# 每个智能体的攻击有一轮的冷却时间。
# 在游戏中，我们能够控制一个队伍的所有智能体与另一个队伍的智能体对战。
# 另一个队伍的智能体使用固定的算法：攻击在范围内最近的敌人，如果攻击范围内没有敌人，则向敌人靠近。
team_size = 2
grid_size = (15, 15)
#创建Combat环境，格子世界的大小为15x15，己方智能体和敌方智能体数量都为2
env = Combat(grid_shape=grid_size, n_agents=team_size, n_opponents=team_size)


# 在训练时使用了参数共享（parameter sharing）的技巧，即对于所有智能体使用同一套策略参数
# 这样做的好处是能够使得模型训练数据更多，同时训练更稳定。
# 能够这样做的前提是，两个智能体是同质的（homogeneous），即它们的状态空间和动作空间是完全一致的，并且它们的优化目标也完全一致。
# 感兴趣的读者也可以自行实现非参数共享版本的 IPPO，此时每个智能体就是一个独立的 PPO 的实例。
# 和之前的一些实验不同，这里不再展示智能体获得的回报，而是将 IPPO 训练的两个智能体团队的胜率作为主要的实验结果。
state_dim = env.observation_space[0].shape[0]
action_dim = env.action_space[0].n
#两个智能体共享同一个策略
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, epochs, para_GAE_lmbda, para_PPO_clip, discount_fac, device)

win_list = []
for i in range(10):
    with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
        for i_episode in range(int(num_episodes / 10)):
            transition_dict_1 = {
                'states': [],
                'actions': [],
                'next_states': [],
                'rewards': [],
                'dones': []
            }
            transition_dict_2 = {
                'states': [],
                'actions': [],
                'next_states': [],
                'rewards': [],
                'dones': []
            }
            s = env.reset()
            terminal = False
            while not terminal:
                a_1 = agent.take_action(s[0])
                a_2 = agent.take_action(s[1])
                next_s, r, done, info = env.step([a_1, a_2])
                
                transition_dict_1['states'].append(s[0])
                transition_dict_1['actions'].append(a_1)
                transition_dict_1['next_states'].append(next_s[0])
                transition_dict_1['rewards'].append(r[0] + 100 if info['win'] else r[0] - 0.1)
                transition_dict_1['dones'].append(False)
                
                transition_dict_2['states'].append(s[1])
                transition_dict_2['actions'].append(a_2)
                transition_dict_2['next_states'].append(next_s[1])
                transition_dict_2['rewards'].append(r[1] + 100 if info['win'] else r[1] - 0.1)
                transition_dict_2['dones'].append(False)
                
                s = next_s
                terminal = all(done)
            win_list.append(1 if info["win"] else 0)
            agent.update(transition_dict_1)
            agent.update(transition_dict_2)
            if (i_episode + 1) % 100 == 0:
                pbar.set_postfix({
                    'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),
                    'return':'%.3f' % np.mean(win_list[-100:])
                })
            pbar.update(1)


win_array = np.array(win_list)
#每100条轨迹取一次平均
win_array = np.mean(win_array.reshape(-1, 100), axis=1)

episodes_list = np.arange(win_array.shape[0]) * 100
plt.plot(episodes_list, win_array)
plt.xlabel('Episodes')
plt.ylabel('Win rate')
plt.title('Independent PPO on Combat')
plt.show()











